import torch
from torch import nn
import torch.nn.functional as F
from torch_geometric.utils import scatter

from Utils import MLP
import pdb

class GNRF(nn.Module):
    """Right-hand side ``f(t, H)`` of a graph feature flow over a fixed edge set.

    Each call builds a per-edge update from the features of both endpoints,
    scaled by a curvature coefficient. That coefficient is either a single
    learnable scalar or the output of two edge MLPs. The updates are
    mean-aggregated back onto nodes. The name suggests a graph neural
    Ricci-flow model (TODO confirm against the accompanying paper/code).

    ``set_edges`` must be called before ``forward``.
    """

    def __init__(self, args):
        """Build the module from a config namespace.

        Reads ``args.damping``, ``args.edgenet``, ``args.beta``, and, when
        ``edgenet`` is truthy, also ``args.num_hid``, ``args.dropout`` and
        ``args.channel_curv``.
        """
        super().__init__()
        self.args = args
        self.edge_index = None  # populated by set_edges()
        self.damping = args.damping  # normalize node rows before/after the update
        self.edgenet = args.edgenet  # learn curvature per edge instead of a scalar
        self.beta = float(args.beta)  # weight of the delta_d correction term
        if self.edgenet:
            # Positional MLP args are (in, hidden?, out?, layers?, dropout?) —
            # assumed from usage; verify against Utils.MLP.
            self.mlp_1 = MLP(2 * args.num_hid, args.num_hid, args.num_hid, 2, args.dropout)
            if args.channel_curv:
                # One curvature value per feature channel.
                self.mlp_2 = MLP(2 * args.num_hid, args.num_hid, args.num_hid, 2, args.dropout)
            else:
                # A single curvature value per edge.
                self.mlp_2 = MLP(2 * args.num_hid, args.num_hid, 1, 2, args.dropout)
        else:
            # Shared scalar curvature, clamped to [1e-8, 1] in forward().
            self.a = nn.Parameter(torch.tensor(0.5))

    def set_edges(self, edge_index):
        """Store the edge index used by ``forward`` and ``curvature``.

        ``edge_index`` is indexed as ``edge_index[0]`` (endpoint i, the
        aggregation target) and ``edge_index[1]`` (endpoint j).
        """
        self.edge_index = edge_index

    def forward(self, t, H):
        """Return the time derivative of the node features at state ``H``.

        Args:
            t: Integration time. Unused; accepted so the module has the
                ``f(t, y)`` signature ODE solvers call (assumed; confirm caller).
            H: Node feature matrix with one row per node.

        Returns:
            Tensor with exactly ``H.size(0)`` rows. Row ``n`` is the mean of the
            edge updates over all edges whose ``edge_index[0]`` equals ``n``.
            Nodes that never occur in ``edge_index[0]`` get zeros. When
            ``damping`` is enabled, every row is L2-normalized.
        """
        num_nodes = H.size(0)
        if self.damping:
            # Put features on the unit sphere so dot_ij below is a cosine.
            H = H / (torch.norm(H, p=2, dim=1, keepdim=True) + 1e-8)

        idx_i, idx_j = self.edge_index[0], self.edge_index[1]
        H_i = H[idx_i]
        H_j = H[idx_j]

        if self.edgenet:
            curv = self.curvature(H_i, H_j, num_nodes=num_nodes)
        else:
            curv = torch.clamp(self.a, 1e-8, 1)

        dot_ij = (H_i * H_j).sum(dim=1, keepdim=True)
        uu = -H_j + H_i * dot_ij
        vv = -H_i + H_j * dot_ij
        # NOTE: without damping, dot_ij is an unnormalized dot product, so this
        # exp (and beta / ww) can overflow/underflow for large-magnitude features.
        ww = torch.exp(-(1.0 - dot_ij))
        delta_d = (self.beta / ww) * torch.sum(uu * vv, dim=1, keepdim=True)

        if self.damping:
            # Component of H_j orthogonal to H_i (H_i is unit-norm here).
            H_edge = curv * (H_j - dot_ij * H_i)
        else:
            H_edge = curv * (H_j - H_i)
        H_edge = H_edge - delta_d

        # dim_size keeps the result row-aligned with H. Without it, scatter
        # sizes the output by max(idx_i) + 1 and silently drops trailing nodes
        # (or returns 0 rows for an empty edge set).
        H = scatter(H_edge, idx_i, dim=0, dim_size=num_nodes, reduce="mean")
        if self.damping:
            H = H / (torch.norm(H, p=2, dim=1, keepdim=True) + 1e-8)
        return H

    def curvature(self, H_i, H_j, num_nodes=None):
        """Predict curvature coefficients for every edge with the two edge MLPs.

        Args:
            H_i: Features of endpoint i of each edge (``H[edge_index[0]]``).
            H_j: Features of endpoint j of each edge (``H[edge_index[1]]``).
            num_nodes: Size of the intermediate per-node table. It must exceed
                every index in ``edge_index``. Defaults to
                ``edge_index.max() + 1``, or 0 for an empty edge set.

        Returns:
            ``mlp_2`` applied to each edge. Its width follows the constructor:
            ``num_hid``-wide when ``channel_curv`` was set, else width 1
            (assuming MLP's third argument is its output width).
        """
        if num_nodes is None:
            num_nodes = int(self.edge_index.max()) + 1 if self.edge_index.numel() > 0 else 0
        idx_i, idx_j = self.edge_index[0], self.edge_index[1]

        edge_emb = F.relu(self.mlp_1(torch.cat((H_i, H_j), dim=1)))
        # Sum edge embeddings onto endpoint i. The table must be large enough for
        # the idx_j gather below, which is why dim_size is passed explicitly.
        node_emb = scatter(edge_emb, idx_i, dim=0, dim_size=num_nodes, reduce="sum")
        pair_emb = torch.cat((node_emb[idx_i], node_emb[idx_j]), dim=1)
        return self.mlp_2(pair_emb)
